{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys, argparse\n",
    "sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('..'))))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import resnet50\n",
    "from thop import profile\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from models import TSN\n",
    "from opts import parser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_path=\"data/something_train.txt\"\n",
    "val_path=\"data/something_val.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name=\"something\"\n",
    "netType1=\"TSM\"\n",
    "netType2=\"MS\"\n",
    "\n",
    "batch_size=1\n",
    "learning_rate=0.01\n",
    "num_segments_8=8\n",
    "num_segments_16=16\n",
    "num_segments_32=32\n",
    "num_segments_128=128\n",
    "mode=1\n",
    "dropout=0.3\n",
    "iter_size=1\n",
    "num_workers=5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "input1 = torch.rand(num_segments_8,3,224,224).cuda()\n",
    "input2 = torch.rand(num_segments_16,3,224,224).cuda()\n",
    "input3 = torch.rand(num_segments_32,3,224,224).cuda()\n",
    "input4 = torch.rand(num_segments_128,3,224,224).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.argv = ['main.py', dataset_name, 'RGB', train_path, val_path, '--arch',\n",
    "            str(netType1), '--num_segments', str(num_segments_8), '--mode', str(mode),\n",
    "            '--gd', '200', '--lr', str(learning_rate), '--lr_steps',\n",
    "            '20', '30', '--epochs', '35', '-b', str(batch_size), '-i',\n",
    "            str(iter_size), '-j', str(num_workers), '--dropout',\n",
    "            str(dropout),\n",
    "            '--consensus_type', 'avg', '--eval-freq', '1', '--rgb_prefix', 'img_',\n",
    "            '--pretrained_parts', 'finetune', '--no_partialbn',\n",
    "            '-p', '20', '--nesterov', 'True']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parser.parse_args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------------\n",
      "TSM Configurations:\n",
      "- dataset: something\n",
      "- modality: RGB\n",
      "- train_list: data/something_train.txt\n",
      "- val_list: data/something_val.txt\n",
      "- arch: TSM\n",
      "- num_segments: 8\n",
      "- mode: 1\n",
      "- consensus_type: avg\n",
      "- pretrained_parts: finetune\n",
      "- k: 3\n",
      "- dropout: 0.3\n",
      "- loss_type: nll\n",
      "- rep_flow: False\n",
      "- epochs: 35\n",
      "- batch_size: 1\n",
      "- iter_size: 1\n",
      "- lr: 0.01\n",
      "- lr_steps: [20.0, 30.0]\n",
      "- momentum: 0.9\n",
      "- weight_decay: 0.0005\n",
      "- clip_gradient: 200.0\n",
      "- no_partialbn: True\n",
      "- nesterov: True\n",
      "- print_freq: 20\n",
      "- eval_freq: 1\n",
      "- workers: 5\n",
      "- resume: \n",
      "- evaluate: False\n",
      "- snapshot_pref: \n",
      "- val_output_folder: \n",
      "- start_epoch: 0\n",
      "- gpus: None\n",
      "- flow_prefix: img_\n",
      "- rgb_prefix: img_\n",
      "------------------------------------\n"
     ]
    }
   ],
   "source": [
    "args_dict = args.__dict__\n",
    "print(\"------------------------------------\")\n",
    "print(args.arch+\" Configurations:\")\n",
    "for key in args_dict.keys():\n",
    "    print(\"- {}: {}\".format(key, args_dict[key]))\n",
    "print(\"------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "if args.dataset == 'ucf101':\n",
    "    num_class = 101\n",
    "    rgb_read_format = \"{:05d}.jpg\"\n",
    "elif args.dataset == 'hmdb51':\n",
    "    num_class = 51\n",
    "    rgb_read_format = \"{:05d}.jpg\"        \n",
    "elif args.dataset == 'kinetics':\n",
    "    num_class = 400\n",
    "    rgb_read_format = \"{:05d}.jpg\"\n",
    "elif args.dataset == 'something':\n",
    "    num_class = 174\n",
    "    rgb_read_format = \"{:05d}.jpg\"\n",
    "elif args.dataset == 'tinykinetics':\n",
    "    num_class = 150\n",
    "    rgb_read_format = \"{:05d}.jpg\"        \n",
    "else:\n",
    "    raise ValueError('Unknown dataset '+args.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Initializing TSN with base model: TSM.\n",
      "TSN Configurations:\n",
      "    input_modality:     RGB\n",
      "    num_segments:       8\n",
      "    new_length:         1\n",
      "    consensus_module:   avg\n",
      "    dropout_ratio:      0.3\n",
      "        \n",
      "\n",
      "Initializing TSN with base model: MS.\n",
      "TSN Configurations:\n",
      "    input_modality:     RGB\n",
      "    num_segments:       8\n",
      "    new_length:         1\n",
      "    consensus_module:   avg\n",
      "    dropout_ratio:      0.3\n",
      "        \n"
     ]
    }
   ],
   "source": [
    "TSM_8frame = TSN(num_class, args.num_segments, args.pretrained_parts, args.modality,\n",
    "                base_model=netType1,\n",
    "                consensus_type=args.consensus_type, dropout=args.dropout, partial_bn=not args.no_partialbn).cuda()\n",
    "MS_8frame = TSN(num_class, args.num_segments, args.pretrained_parts, args.modality,\n",
    "                base_model=netType2,\n",
    "                consensus_type=args.consensus_type, dropout=args.dropout, partial_bn=not args.no_partialbn).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.num_segments=num_segments_16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Initializing TSN with base model: TSM.\n",
      "TSN Configurations:\n",
      "    input_modality:     RGB\n",
      "    num_segments:       16\n",
      "    new_length:         1\n",
      "    consensus_module:   avg\n",
      "    dropout_ratio:      0.3\n",
      "        \n",
      "\n",
      "Initializing TSN with base model: MS.\n",
      "TSN Configurations:\n",
      "    input_modality:     RGB\n",
      "    num_segments:       16\n",
      "    new_length:         1\n",
      "    consensus_module:   avg\n",
      "    dropout_ratio:      0.3\n",
      "        \n"
     ]
    }
   ],
   "source": [
    "TSM_16frame = TSN(num_class, args.num_segments, args.pretrained_parts, args.modality,\n",
    "                base_model=netType1,\n",
    "                consensus_type=args.consensus_type, dropout=args.dropout, partial_bn=not args.no_partialbn).cuda()\n",
    "MS_16frame = TSN(num_class, args.num_segments, args.pretrained_parts, args.modality,\n",
    "                base_model=netType2,\n",
    "                consensus_type=args.consensus_type, dropout=args.dropout, partial_bn=not args.no_partialbn).cuda()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# temperature\n",
    "temperature = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No BN layer Freezing.\n",
      "No BN layer Freezing.\n",
      "No BN layer Freezing.\n",
      "No BN layer Freezing.\n",
      "No BN layer Freezing.\n",
      "No BN layer Freezing.\n",
      "No BN layer Freezing.\n",
      "No BN layer Freezing.\n"
     ]
    }
   ],
   "source": [
    "flops1, params1 = profile(TSM_8frame, inputs=(input1, temperature), verbose=False)\n",
    "flops2, params2 = profile(MS_8frame, inputs=(input1, temperature), verbose=False)\n",
    "flops3, params3 = profile(TSM_16frame, inputs=(input2, temperature), verbose=False)\n",
    "flops4, params4 = profile(MS_16frame, inputs=(input2, temperature), verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def human_format(num):\n",
    "    magnitude = 0\n",
    "    while abs(num) >= 1000:\n",
    "        magnitude += 1\n",
    "        num /= 1000.0\n",
    "    # add more suffixes if you need them\n",
    "    return '%.3f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Models\tFrames\tFLOPs\tParams\n",
      "==============================\n",
      "TSM\t8\t14.589G\t11.266M\n",
      "MS\t8\t14.883G\t11.287M\n",
      "TSM\t16\t29.178G\t11.266M\n",
      "MS\t16\t29.788G\t11.287M\n"
     ]
    }
   ],
   "source": [
    "print('Models\\tFrames\\tFLOPs\\tParams')\n",
    "print('='*30)\n",
    "print('%s\\t%d\\t%s\\t%s' % (netType1, num_segments_8, human_format(flops1), human_format(params1)))\n",
    "print('%s\\t%d\\t%s\\t%s' % (netType2, num_segments_8, human_format(flops2), human_format(params2)))\n",
    "print('%s\\t%d\\t%s\\t%s' % (netType1, num_segments_16, human_format(flops3), human_format(params3)))\n",
    "print('%s\\t%d\\t%s\\t%s' % (netType2, num_segments_16, human_format(flops4), human_format(params4)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
